# ------------------------------------------------------------------------------
# Random effects of shift variables on test rate
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bsub -q big -R "rusage[mem=25000]" bash 03_shift_random_effects.sh
# ------------------------------------------------------------------------------

# Setup ------------------------------------------------------------------------
# Fix the RNG seed before anything else runs. Presumably this is for the
# simulation-based exactLRT() call in get_shift_re() -- TODO confirm nothing
# else in the pipeline relies on it.
set.seed(504)

# Libraries --------------------------------------------------------------------
library(here)
library(yaml)
library(data.table)
library(tidyverse)
library(testit) # assert()
library(glue) # glue strings
library(broom) # tidy regression results
library(lme4) # lmer linear random effects model
library(RLRsim) # exactLRT

# Scratch directory: the input data are read from here and the fitted models
# are written back to it.
temp <- here::here("code", "05_natural_experiment", "temp")
# NOTE(review): `overnight_lab` is not referenced in the visible part of this
# script -- confirm whether it is still needed.
overnight_lab <- ""
# Project utility module loaded with modules::use(). NOTE(review): `u` is not
# referenced in the visible part of this script.
u <- modules::use(here::here("lib", "util.R"))

# Helper Functions -------------------------------------------------------------
# Build the label that identifies one random-effects model specification.
#
# Args:
#   shift_bin, sample, outcome: interpolated verbatim into the label.
#   incl_yhat: when TRUE, a "_inclyhat_<form>" suffix is appended.
#   yhat_term: name of the yhat term; <form> is "deciles" when the name
#              contains "tile" and "linear" otherwise.
#
# Returns: a glue string "re_<shift_bin>_<sample>_<outcome>" plus the
# optional suffix. `yhat_term` is also printed as a progress trace.
get_mod_label <- function(shift_bin, sample, outcome, incl_yhat, yhat_term) {
  print(yhat_term)

  # FALSE -> position 1 ("linear"), TRUE -> position 2 ("deciles").
  has_tile <- grepl("tile", yhat_term)
  yhat_termlab <- c("linear", "deciles")[has_tile + 1L]

  # Suffix is empty unless the yhat term is part of the model.
  yhat_lab <- ifelse(
    incl_yhat,
    glue("_inclyhat_{yhat_termlab}"),
    ""
  )

  glue("re_{shift_bin}_{sample}_{outcome}{yhat_lab}")
}

# Fit a random-intercept model for one shift binning and test whether the
# shift-level variance component is needed.
#
# Reads the global data frame `DT` and writes both fitted models as .rds files
# into the global directory `temp`.
#
# Args:
#   shift_bin: name of the grouping column used in the `(1 | shift_bin)` term.
#   sample:    one of "full", "train", "test"; "train"/"test" keep only rows
#              with that value of `split`, "full" keeps all of `DT`.
#   outcome:   name of the response column.
#   incl_yhat: when TRUE, rows with split == "val" are dropped and `yhat_term`
#              is added to the fixed-effect terms.
#   yhat_term: term label appended to the fixed effects when `incl_yhat`.
#   lab:       label used in the log messages and in the output file names.
#   ...:       ignored; absorbs any extra columns handed over by pmap().
#
# Returns: the `lmer` fit (estimated by ML, not REML).
get_shift_re <- function(
  shift_bin, sample, outcome, incl_yhat, yhat_term, lab, ...
) {
  print(lab)
  possible_samples <- c("full", "train", "test")
  assert("sample in possible_samples", sample %in% possible_samples)
  if (sample == "train") {
    dt <- filter(DT, split == "train")
  } else if (sample == "test") {
    dt <- filter(DT, split == "test")
  } else {
    dt <- DT
  }

  # The leading "0" removes the intercept from both formulas.
  FEs <- c("0", "hour", "wday", "week", "year")
  if (incl_yhat) {
    dt <- filter(dt, split != "val")
    FEs <- c(FEs, yhat_term)
  }

  # Full model: fixed effects plus a random intercept per shift bin. It is
  # fitted by ML (REML = FALSE) so its log-likelihood can be compared with
  # the lm() fit below.
  covars_full <- c(FEs, glue("(1 | {shift_bin})"))
  form <- reformulate(response = outcome, termlabels = covars_full)
  re_fit <- lmer(form, data = dt, REML = FALSE)

  # Reduced model: same fixed effects, no random intercept.
  form_reduced <- reformulate(response = outcome, termlabels = FEs)
  re_fit_basic <- lm(form_reduced, data = dt)

  # Each fit drops incomplete rows on its own, so missing values in the
  # grouping column would leave the two models on different rows and make the
  # likelihood comparison meaningless. Surface that instead of hiding it.
  if (nobs(re_fit) != nobs(re_fit_basic)) {
    warning(
      "Models for ", lab, " were fitted on different numbers of rows (",
      nobs(re_fit), " vs ", nobs(re_fit_basic),
      "); the likelihood ratio tests are not comparable.",
      call. = FALSE
    )
  }

  message("LRT")
  # NOTE(review): a variance of zero lies on the boundary of the parameter
  # space, so the chi-square(1) reference used here is likely conservative;
  # the exactLRT() result below is the simulation-based alternative.
  likelihood_diff <- as.numeric(
    -2 * logLik(re_fit_basic) + 2 * logLik(re_fit)
  )
  message("Likelihood difference for ", lab, ": ", likelihood_diff)
  p_val <- pchisq(likelihood_diff, df = 1, lower.tail = FALSE)
  message("LRT p-value for ", lab, ": ", p_val)

  message("RLRsim")
  sim_result <- exactLRT(re_fit, re_fit_basic)
  message("Likelihood difference for ", lab, ": ", sim_result$statistic)
  message("LRT p-value for ", lab, ": ", sim_result$p)

  print(sim_result)

  # Persist both fits; the in-memory model is returned directly rather than
  # being read back from the file that was just written.
  write_rds(re_fit, file.path(temp, glue("fit_{lab}.rds")))
  write_rds(re_fit_basic, file.path(temp, glue("basic_fit_{lab}.rds")))
  re_fit
}

# Load Data --------------------------------------------------------------------
message("Loading data...")
paths <- read_yaml(here::here("lib", "filepaths.yml"))

# Keep only the train and test splits; rows with any other (or missing) split
# value are dropped here.
DT <- filter(
  readRDS(file.path(temp, "shift_test_rates.rds")),
  split %in% c("train", "test")
)

# Models  ----------------------------------------------------------------------
message("Fitting RE models...")
# One row per model specification: every combination of the values below.
config <- crossing(
  shift_bin = c("shift_12", "shift_08"),
  sample = "full",
  outcome = "test_010_day",
  incl_yhat = TRUE,
  yhat_term = "p__ensemble__stent_or_cabg_010_day__tested"
)
# Label every specification first, then fit it; the `lab` column is handed to
# get_shift_re() together with the configuration columns.
re_models <- mutate(config, lab = pmap(config, get_mod_label))
re_models <- mutate(re_models, model = pmap(re_models, get_shift_re))

# Save -------------------------------------------------------------------------
save_fp <- file.path(temp, "re_models.rds")
message("Saving RE models to ", save_fp, "...")
write_rds(re_models, save_fp)

message("Done.")
